{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "import py_code.input_data as input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Read the MNIST archives from MNIST_data/ with one-hot encoded labels.\n",
    "mnist = input_data.read_data_sets(\n",
    "    \"MNIST_data/\",\n",
    "    one_hot=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The downloaded dataset is split into three subsets:\n",
    "#   - 55,000 rows of training data   (mnist.train)\n",
    "#   - 5,000 rows of validation data  (mnist.validation)\n",
    "#   - 10,000 rows of test data       (mnist.test)\n",
    "# Every image is a 28x28 black-and-white picture, so each row is a\n",
    "# 784-dimensional vector.\n",
    "trainimg, trainlabel = mnist.train.images, mnist.train.labels\n",
    "testimg, testlabel = mnist.test.images, mnist.test.labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 784)\n",
      "(55000, 10)\n",
      "(10000, 784)\n",
      "(10000, 10)\n",
      "<bound method DataSet.next_batch of <tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x0000022229C15CF8>>\n",
      "[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "# Shapes of the train/test image and label arrays, one per line.\n",
    "print(trainimg.shape, trainlabel.shape, testimg.shape, testlabel.shape, sep=\"\\n\")\n",
    "# Repr of the bound next_batch method on the training split.\n",
    "print(mnist.train.next_batch)\n",
    "# One-hot label of the first training example.\n",
    "print(trainlabel[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs: rows of 784 pixel values and rows of 10 one-hot label entries.\n",
    "x = tf.placeholder(shape=(None, 784), dtype=tf.float32)\n",
    "y = tf.placeholder(shape=(None, 10), dtype=tf.float32)\n",
    "\n",
    "# Trainable weights and bias of the linear model, both initialised to zero.\n",
    "W = tf.Variable(tf.zeros([784, 10]))\n",
    "b = tf.Variable(tf.zeros([10]))\n",
    "\n",
    "# Softmax regression: class scores are x*W + b, probabilities are their softmax.\n",
    "logits = tf.matmul(x, W) + b\n",
    "actv = tf.nn.softmax(logits)\n",
    "\n",
    "# Cross-entropy loss averaged over the batch.\n",
    "# log_softmax is applied to the logits directly instead of tf.log(actv):\n",
    "# taking the log of an already-computed softmax gives log(0) = -inf as soon as\n",
    "# a probability underflows, and 0 * -inf then turns the loss into NaN.\n",
    "cost = tf.reduce_mean(-tf.reduce_sum(y * tf.nn.log_softmax(logits), axis=1))\n",
    "\n",
    "learning_rate = 0.01\n",
    "\n",
    "# Plain gradient-descent step that minimises the loss.\n",
    "optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of training epochs (full passes of the outer training loop).\n",
    "training_epochs = 50\n",
    "# Mini-batch size.\n",
    "batch_size = 100\n",
    "# Metrics are reported every `display_step` epochs.\n",
    "display_step = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-example flag: does the arg-max of the model output match the arg-max of the label?\n",
    "pred = tf.equal(\n",
    "    tf.argmax(actv, axis=1),\n",
    "    tf.argmax(y, axis=1),\n",
    ")\n",
    "# Accuracy: mean of the flags after casting them to float.\n",
    "accr = tf.reduce_mean(tf.cast(pred, \"float\"))\n",
    "# Variable initialisation op.\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000/050 cost: 1.176283828 train_acc: 0.840 test_acc: 0.853\n",
      "Epoch: 005/050 cost: 0.440940045 train_acc: 0.860 test_acc: 0.895\n",
      "Epoch: 010/050 cost: 0.383335689 train_acc: 0.920 test_acc: 0.905\n",
      "Epoch: 015/050 cost: 0.357283105 train_acc: 0.910 test_acc: 0.908\n",
      "Epoch: 020/050 cost: 0.341510462 train_acc: 0.930 test_acc: 0.912\n",
      "Epoch: 025/050 cost: 0.330527925 train_acc: 0.870 test_acc: 0.914\n",
      "Epoch: 030/050 cost: 0.322377158 train_acc: 0.900 test_acc: 0.916\n",
      "Epoch: 035/050 cost: 0.315936385 train_acc: 0.900 test_acc: 0.917\n",
      "Epoch: 040/050 cost: 0.310713485 train_acc: 0.940 test_acc: 0.918\n",
      "Epoch: 045/050 cost: 0.306388705 train_acc: 0.930 test_acc: 0.919\n",
      "DONE\n"
     ]
    }
   ],
   "source": [
    "# The session is left open on purpose so later cells can keep using `sess`;\n",
    "# call sess.close() once the notebook is finished with it.\n",
    "sess = tf.Session()\n",
    "sess.run(init)\n",
    "\n",
    "# Mini-batches per epoch. Loop-invariant, so it is computed once, using\n",
    "# integer division instead of int(a / b).\n",
    "num_batch = mnist.train.num_examples // batch_size\n",
    "\n",
    "for epoch in range(training_epochs):\n",
    "    avg_cost = 0.\n",
    "    for i in range(num_batch):\n",
    "        # next_batch fetches the next mini-batch of images and labels.\n",
    "        batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n",
    "        feeds = {x: batch_xs, y: batch_ys}\n",
    "        # Run the training step and fetch the loss in the SAME run() call, so the\n",
    "        # graph is not evaluated a second time per batch just for reporting.\n",
    "        avg_cost += sess.run([optm, cost], feed_dict=feeds)[1] / num_batch\n",
    "    # Report metrics every `display_step` epochs.\n",
    "    if epoch % display_step == 0:\n",
    "        # NOTE: train_acc is measured on the last mini-batch of the epoch only,\n",
    "        # not on the whole training set; test_acc uses the full test split.\n",
    "        feeds_train = {x: batch_xs, y: batch_ys}\n",
    "        feeds_test = {x: mnist.test.images, y: mnist.test.labels}\n",
    "        train_acc = sess.run(accr, feed_dict=feeds_train)\n",
    "        test_acc = sess.run(accr, feed_dict=feeds_test)\n",
    "        print(\"Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f\"\n",
    "              % (epoch, training_epochs, avg_cost, train_acc, test_acc))\n",
    "print(\"DONE\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
